#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)

import os
import json
import time
import math
import torch
from torch import nn
from enum import Enum
from dataclasses import dataclass
from funasr.register import tables
from typing import List, Tuple, Dict, Any, Optional

from funasr.utils.datadir_writer import DatadirWriter
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank


class VadStateMachine(Enum):
	"""Top-level state of the VAD endpoint state machine.

	In multiple-utterance mode the detector returns to
	``kVadInStateStartPointNotDetected`` after every confirmed end point.
	"""

	# Waiting for speech: incoming frames are treated as leading silence.
	kVadInStateStartPointNotDetected = 1
	# A start point was confirmed; frames are appended to the open segment.
	kVadInStateInSpeechSegment = 2
	# The open segment was closed (end point, timeout or final frame).
	kVadInStateEndPointDetected = 3


class FrameState(Enum):
	"""Speech/silence label assigned to a single feature frame."""

	# Label could not be determined (initial value before classification).
	kFrameStateInvalid = -1
	# Frame classified as speech.
	kFrameStateSpeech = 1
	# Frame classified as silence/noise.
	kFrameStateSil = 0


# Smoothed per-frame voice/unvoice transition reported by WindowDetector.
class AudioChangeState(Enum):
	"""Transition between the previous and the current smoothed frame state."""

	# Stayed in speech.
	kChangeStateSpeech2Speech = 0
	# Left speech: the window's speech-vote count fell to the speech-to-silence threshold.
	kChangeStateSpeech2Sil = 1
	# Stayed in silence.
	kChangeStateSil2Sil = 2
	# Entered speech: the window's speech-vote count reached the silence-to-speech threshold.
	kChangeStateSil2Speech = 3
	# Not produced anywhere in this file.
	kChangeStateNoBegin = 4
	# Returned when the input frame label was neither speech nor silence.
	kChangeStateInvalid = 5


class VadDetectMode(Enum):
	"""How many utterances the detector looks for in one stream.

	Compared by ``.value`` against ``VADXOptions.detect_mode``.
	"""

	# Stop after one utterance; also gives up after max_start_silence_time of leading silence.
	kVadSingleUtteranceDetectMode = 0
	# Reset after every end point and keep detecting further utterances.
	kVadMutipleUtteranceDetectMode = 1


class VADXOptions:
	"""
	Author: Speech Lab of DAMO Academy, Alibaba Group
	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
	https://arxiv.org/abs/1803.05030

	Plain container for the VAD post-processing options.

	Durations named ``*_time``/``*_ms`` are in milliseconds.  Unknown keyword
	arguments are accepted and silently ignored, so a full model config dict
	can be splatted in directly.

	Notable options:
	    detect_mode: ``VadDetectMode`` value (single vs. multiple utterances).
	    sil_pdf_ids: encoder output columns whose scores are summed as the
	        silence/noise score of a frame; must have ``silence_pdf_num`` entries.
	    speech_noise_thres: margin by which the speech probability must
	        exceed the noise probability for a frame to count as speech.
	    decibel_thres / snr_thres: frames below these energy/SNR levels are
	        forced to silence.
	"""
	
	def __init__(
		self,
		sample_rate: int = 16000,
		detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
		snr_mode: int = 0,
		max_end_silence_time: int = 800,
		max_start_silence_time: int = 3000,
		do_start_point_detection: bool = True,
		do_end_point_detection: bool = True,
		window_size_ms: int = 200,
		sil_to_speech_time_thres: int = 150,
		speech_to_sil_time_thres: int = 150,
		speech_2_noise_ratio: float = 1.0,
		do_extend: int = 1,
		lookback_time_start_point: int = 200,
		lookahead_time_end_point: int = 100,
		max_single_segment_time: int = 60000,
		nn_eval_block_size: int = 8,
		dcd_block_size: int = 4,
		snr_thres: float = -100.0,
		noise_frame_num_used_for_snr: int = 100,
		decibel_thres: float = -100.0,
		speech_noise_thres: float = 0.6,
		fe_prior_thres: float = 1e-4,
		silence_pdf_num: int = 1,
		sil_pdf_ids: List[int] = [0],
		speech_noise_thresh_low: float = -0.1,
		speech_noise_thresh_high: float = 0.3,
		output_frame_probs: bool = False,
		frame_in_ms: int = 10,
		frame_length_ms: int = 25,
		**kwargs,
	):
		self.sample_rate = sample_rate
		self.detect_mode = detect_mode
		self.snr_mode = snr_mode
		self.max_end_silence_time = max_end_silence_time
		self.max_start_silence_time = max_start_silence_time
		self.do_start_point_detection = do_start_point_detection
		self.do_end_point_detection = do_end_point_detection
		self.window_size_ms = window_size_ms
		self.sil_to_speech_time_thres = sil_to_speech_time_thres
		self.speech_to_sil_time_thres = speech_to_sil_time_thres
		self.speech_2_noise_ratio = speech_2_noise_ratio
		self.do_extend = do_extend
		self.lookback_time_start_point = lookback_time_start_point
		self.lookahead_time_end_point = lookahead_time_end_point
		self.max_single_segment_time = max_single_segment_time
		self.nn_eval_block_size = nn_eval_block_size
		self.dcd_block_size = dcd_block_size
		self.snr_thres = snr_thres
		self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
		self.decibel_thres = decibel_thres
		self.speech_noise_thres = speech_noise_thres
		self.fe_prior_thres = fe_prior_thres
		self.silence_pdf_num = silence_pdf_num
		# Copy: the default ``[0]`` is a single list object created at definition
		# time; storing it by reference would share it between every instance.
		self.sil_pdf_ids = list(sil_pdf_ids)
		self.speech_noise_thresh_low = speech_noise_thresh_low
		self.speech_noise_thresh_high = speech_noise_thresh_high
		self.output_frame_probs = output_frame_probs
		self.frame_in_ms = frame_in_ms
		self.frame_length_ms = frame_length_ms


class E2EVadSpeechBufWithDoa(object):
	"""
	Author: Speech Lab of DAMO Academy, Alibaba Group
	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
	https://arxiv.org/abs/1803.05030

	One speech segment emitted by the VAD.

	``start_ms``/``end_ms`` bound the segment; the two ``contain_seg_*`` flags
	record whether the real start/end point has been seen yet.  ``doa``
	(presumably direction of arrival, going by the class name) is only ever
	set to 0 in this file.
	"""
	
	def __init__(self):
		# A fresh segment is identical to a reset one.
		self.Reset()
	
	def Reset(self):
		"""Return the segment to an empty, zero-length state."""
		self.start_ms = 0
		self.end_ms = 0
		self.buffer = []
		self.contain_seg_start_point = False
		self.contain_seg_end_point = False
		self.doa = 0


class E2EVadFrameProb(object):
	"""
	Author: Speech Lab of DAMO Academy, Alibaba Group
	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
	https://arxiv.org/abs/1803.05030

	Per-frame scoring record, collected by ``FsmnVADStreaming.GetFrameState``
	when ``output_frame_probs`` is enabled.
	"""
	
	def __init__(self):
		# Log-domain noise and speech scores of the frame.
		self.noise_prob, self.speech_prob = 0.0, 0.0
		# Linear speech score (1 - summed silence scores).
		self.score = 0.0
		self.frame_id = 0
		self.frm_state = 0


class WindowDetector(object):
	"""
	Author: Speech Lab of DAMO Academy, Alibaba Group
	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
	https://arxiv.org/abs/1803.05030

	Sliding-window smoother over per-frame speech/silence labels.

	Keeps the last ``window_size_ms / frame_size_ms`` labels as 0/1 votes in a
	circular buffer.  It flips to speech once the vote sum reaches the
	silence-to-speech frame threshold, and back to silence once it drops to
	the speech-to-silence frame threshold.
	"""
	
	def __init__(self, window_size_ms: int,
	             sil_to_speech_time: int,
	             speech_to_sil_time: int,
	             frame_size_ms: int):
		self.window_size_ms = window_size_ms
		self.sil_to_speech_time = sil_to_speech_time
		self.speech_to_sil_time = speech_to_sil_time
		self.frame_size_ms = frame_size_ms
		
		# Convert every millisecond setting into a frame count.
		self.win_size_frame = int(window_size_ms / frame_size_ms)
		self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
		self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
		
		# Start from an empty (all-silence) window.
		self.Reset()
	
	def Reset(self) -> None:
		"""Clear the window and return to the silence state."""
		self.cur_win_pos = 0
		self.win_sum = 0
		self.win_state = [0] * self.win_size_frame
		self.pre_frame_state = FrameState.kFrameStateSil
		self.cur_frame_state = FrameState.kFrameStateSil
		self.voice_last_frame_count = 0
		self.noise_last_frame_count = 0
		self.hydre_frame_count = 0
	
	def GetWinSize(self) -> int:
		"""Window length in frames."""
		return int(self.win_size_frame)
	
	def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict = {}) -> AudioChangeState:
		"""Push one frame label into the window and report the smoothed transition.

		``frame_count`` and ``cache`` are accepted for interface compatibility
		and are not used.
		"""
		# Turn the label into a 0/1 vote; reject anything that is neither.
		if frameState == FrameState.kFrameStateSpeech:
			vote = 1
		elif frameState == FrameState.kFrameStateSil:
			vote = 0
		else:
			return AudioChangeState.kChangeStateInvalid
		
		# Overwrite the oldest vote in the ring buffer, keeping the running sum in step.
		slot = self.cur_win_pos
		self.win_sum += vote - self.win_state[slot]
		self.win_state[slot] = vote
		self.cur_win_pos = (slot + 1) % self.win_size_frame
		
		previous = self.pre_frame_state
		if previous == FrameState.kFrameStateSil:
			if self.win_sum >= self.sil_to_speech_frmcnt_thres:
				self.pre_frame_state = FrameState.kFrameStateSpeech
				return AudioChangeState.kChangeStateSil2Speech
			return AudioChangeState.kChangeStateSil2Sil
		if previous == FrameState.kFrameStateSpeech:
			if self.win_sum <= self.speech_to_sil_frmcnt_thres:
				self.pre_frame_state = FrameState.kFrameStateSil
				return AudioChangeState.kChangeStateSpeech2Sil
			return AudioChangeState.kChangeStateSpeech2Speech
		return AudioChangeState.kChangeStateInvalid
	
	def FrameSizeMs(self) -> int:
		"""Frame shift in milliseconds."""
		return int(self.frame_size_ms)


class Stats(object):
	"""Mutable per-stream bookkeeping for ``FsmnVADStreaming`` (kept in ``cache["stats"]``).

	Frame positions use -1 to mean "not set yet".  ``last_drop_frames`` counts
	frames whose samples/decibels/scores were discarded by ``ResetDetection``.
	"""
	def __init__(self,
	             sil_pdf_ids,
	             max_end_sil_frame_cnt_thresh,
	             speech_noise_thres,
	             ):
		# Frame positions and counters driving the endpoint state machine.
		self.data_buf_start_frame, self.frm_cnt = 0, 0
		self.latest_confirmed_speech_frame, self.lastest_confirmed_silence_frame = 0, -1
		self.continous_silence_frame_count = 0
		self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
		self.confirmed_start_frame, self.confirmed_end_frame = -1, -1
		self.number_end_time_detected, self.sil_frame = 0, 0
		self.sil_pdf_ids = sil_pdf_ids
		# Running noise level; -100.0 marks "no noise frame observed yet".
		self.noise_average_decibel = -100.0
		self.pre_end_silence_detected, self.next_seg = False, True
		
		# Emitted segments, plus the index of the first one not yet returned.
		self.output_data_buf, self.output_data_buf_offset = [], 0
		self.frame_probs = []
		self.max_end_sil_frame_cnt_thresh = max_end_sil_frame_cnt_thresh
		self.speech_noise_thres = speech_noise_thres
		# Accumulated encoder scores, filled on the first ComputeScores call.
		self.scores, self.max_time_out = None, False
		self.decibel = []
		# Raw-audio buffers, filled on the first ComputeDecibel call.
		self.data_buf = self.data_buf_all = self.waveform = None
		self.last_drop_frames = 0


@tables.register("model_classes", "FsmnVADStreaming")
class FsmnVADStreaming(nn.Module):
	"""
	Author: Speech Lab of DAMO Academy, Alibaba Group
	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
	https://arxiv.org/abs/1803.05030

	Streaming FSMN voice activity detector.

	An FSMN encoder scores every feature frame.  A rule-based state machine
	then turns those scores into speech segments in milliseconds, using
	sliding-window smoothing, start/end silence timeouts, a maximum segment
	length and look-back/look-ahead padding.

	All per-stream state lives in the ``cache`` dict built by ``init_cache``:
	``cache["stats"]`` (Stats), ``cache["windows_detector"]`` (WindowDetector),
	``cache["encoder"]``, ``cache["frontend"]`` and ``cache["prev_samples"]``.
	``self.vad_opts`` is shared by every stream using this model and is
	mutated by ``ComputeScores`` and ``init_cache``.
	"""
	
	def __init__(self,
	             encoder: str = None,
	             encoder_conf: Optional[Dict] = None,
	             vad_post_args: Dict[str, Any] = None,
	             **kwargs,
	             ):
		super().__init__()
		# Every remaining config key is a VAD post-processing option
		# (unknown keys are ignored by VADXOptions).  vad_post_args is unused.
		self.vad_opts = VADXOptions(**kwargs)
		
		encoder_class = tables.encoder_classes.get(encoder)
		encoder = encoder_class(**encoder_conf)
		self.encoder = encoder
		self.encoder_conf = encoder_conf
	
	def ResetDetection(self, cache: dict = {}):
		"""Return the state machine to "start point not detected" for the next utterance.

		When segments have been emitted, this also discards the audio, decibels
		and scores up to the end of the last segment.  The running total of
		discarded frames is recorded in ``stats.last_drop_frames``, so absolute
		frame indices keep counting from the start of the stream.
		"""
		cache["stats"].continous_silence_frame_count = 0
		cache["stats"].latest_confirmed_speech_frame = 0
		cache["stats"].lastest_confirmed_silence_frame = -1
		cache["stats"].confirmed_start_frame = -1
		cache["stats"].confirmed_end_frame = -1
		cache["stats"].vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
		cache["windows_detector"].Reset()
		cache["stats"].sil_frame = 0
		cache["stats"].frame_probs = []
		
		if cache["stats"].output_data_buf:
			assert cache["stats"].output_data_buf[-1].contain_seg_end_point == True
			drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
			# Only drop what was not already dropped by a previous reset.
			real_drop_frames = drop_frames - cache["stats"].last_drop_frames
			cache["stats"].last_drop_frames = drop_frames
			cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(
				self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
			cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:]
			cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :]
	
	def ComputeDecibel(self, cache: dict = {}) -> None:
		"""Append the new chunk's samples and per-frame energies to the stream state.

		Uses row 0 of ``stats.waveform`` (assumed (batch=1, samples); TODO confirm).
		Each analysis frame is ``frame_length_ms`` long, advanced by ``frame_in_ms``,
		and contributes ``10 * log10(sum(x**2) + 1e-6)`` to ``stats.decibel``.
		"""
		frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
		frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
		if cache["stats"].data_buf_all is None:
			# First chunk: data_buf initially points at the same samples.
			cache["stats"].data_buf_all = cache["stats"].waveform[0]
			cache["stats"].data_buf = cache["stats"].data_buf_all
		else:
			cache["stats"].data_buf_all = torch.cat((cache["stats"].data_buf_all, cache["stats"].waveform[0]))
		for offset in range(0, cache["stats"].waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
			cache["stats"].decibel.append(
				10 * math.log10((cache["stats"].waveform[0][offset: offset + frame_sample_length]).square().sum() + \
				                0.000001))
	
	def ComputeScores(self, feats: torch.Tensor, cache: dict = {}) -> None:
		"""Run the encoder on ``feats`` and append its frame scores to ``stats.scores``.

		Side effect: sets the shared ``vad_opts.nn_eval_block_size`` to this
		chunk's frame count, which the Detect*Frames methods iterate over.
		"""
		scores = self.encoder(feats, cache=cache["encoder"]).to('cpu')  # return B * T * D
		assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
		self.vad_opts.nn_eval_block_size = scores.shape[1]
		cache["stats"].frm_cnt += scores.shape[1]  # count total frames
		if cache["stats"].scores is None:
			cache["stats"].scores = scores  # the first calculation
		else:
			cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1)
	
	def PopDataBufTillFrame(self, frame_idx: int, cache: dict = {}) -> None:
		"""Advance ``stats.data_buf`` so that it starts at absolute frame ``frame_idx``.

		Does nothing when the buffer already starts at or after ``frame_idx``.
		"""
		stats = cache["stats"]
		if stats.data_buf_start_frame >= frame_idx:
			return
		frame_shift_samples = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
		# Jump straight to the target frame.  The former frame-by-frame loop reached
		# exactly this state whenever it terminated.  However, it spun forever
		# (changing nothing) once the remaining buffer was shorter than one frame shift.
		stats.data_buf_start_frame = frame_idx
		stats.data_buf = stats.data_buf_all[(frame_idx - stats.last_drop_frames) * frame_shift_samples:]
	
	def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
	                       last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict = {}) -> None:
		"""Extend the current output segment by ``frm_cnt`` frames starting at ``start_frm``.

		A new segment is opened when none exists yet or when
		``first_frm_is_start_point`` is set.  Only the segment's millisecond
		bounds and start/end flags are updated.  Sample counts are checked
		(and inconsistencies printed), but no samples are copied into
		``segment.buffer``.
		"""
		self.PopDataBufTillFrame(start_frm, cache=cache)
		expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
		if last_frm_is_end_point:
			# The last frame's analysis window reaches (frame_length - frame_shift) past its shift.
			extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
			                          self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
			expected_sample_number += int(extra_sample)
		if end_point_is_sent_end:
			expected_sample_number = max(expected_sample_number, len(cache["stats"].data_buf))
		if len(cache["stats"].data_buf) < expected_sample_number:
			print('error in calling pop data_buf\n')
		
		if len(cache["stats"].output_data_buf) == 0 or first_frm_is_start_point:
			cache["stats"].output_data_buf.append(E2EVadSpeechBufWithDoa())
			cache["stats"].output_data_buf[-1].Reset()
			cache["stats"].output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
			cache["stats"].output_data_buf[-1].end_ms = cache["stats"].output_data_buf[-1].start_ms
			cache["stats"].output_data_buf[-1].doa = 0
		cur_seg = cache["stats"].output_data_buf[-1]
		if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
			print('warning\n')
		
		if end_point_is_sent_end:
			data_to_pop = expected_sample_number
		else:
			data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
		if data_to_pop > len(cache["stats"].data_buf):
			print('VAD data_to_pop is bigger than cache["stats"].data_buf.size()!!!\n')
		
		cur_seg.doa = 0
		if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
			print('Something wrong with the VAD algorithm\n')
		cache["stats"].data_buf_start_frame += frm_cnt
		cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
		if first_frm_is_start_point:
			cur_seg.contain_seg_start_point = True
		if last_frm_is_end_point:
			cur_seg.contain_seg_end_point = True
	
	def OnSilenceDetected(self, valid_frame: int, cache: dict = {}):
		"""Record ``valid_frame`` as confirmed silence; before speech starts, also drop its audio."""
		cache["stats"].lastest_confirmed_silence_frame = valid_frame
		if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
			self.PopDataBufTillFrame(valid_frame, cache=cache)
	
	def OnVoiceDetected(self, valid_frame: int, cache: dict = {}) -> None:
		"""Record ``valid_frame`` as confirmed speech and append it to the open segment."""
		cache["stats"].latest_confirmed_speech_frame = valid_frame
		self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache)
	
	def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache: dict = {}) -> None:
		"""Confirm a segment start at ``start_frame``.

		With ``fake_result`` only the bookkeeping is updated and no segment is opened.
		"""
		if cache["stats"].confirmed_start_frame != -1:
			print('not reset vad properly\n')
		else:
			cache["stats"].confirmed_start_frame = start_frame
		
		if not fake_result and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
			self.PopDataToOutputBuf(cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache)
	
	def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache: dict = {}) -> None:
		"""Confirm a segment end at ``end_frame``, first filling any unconfirmed speech frames before it.

		With ``fake_result`` only the bookkeeping is updated and the segment is not closed.
		"""
		for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame):
			self.OnVoiceDetected(t, cache=cache)
		if cache["stats"].confirmed_end_frame != -1:
			print('not reset vad properly\n')
		else:
			cache["stats"].confirmed_end_frame = end_frame
		if not fake_result:
			cache["stats"].sil_frame = 0
			self.PopDataToOutputBuf(cache["stats"].confirmed_end_frame, 1, False, True, is_last_frame, cache=cache)
		cache["stats"].number_end_time_detected += 1
	
	def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int, cache: dict = {}) -> None:
		"""On the stream's final frame, close the open segment at ``cur_frm_idx``."""
		if is_final_frame:
			self.OnVoiceEnd(cur_frm_idx, False, True, cache=cache)
			cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
	
	def GetLatency(self, cache: dict = {}) -> int:
		"""Start-point latency in milliseconds."""
		return int(self.LatencyFrmNumAtStartPoint(cache=cache) * self.vad_opts.frame_in_ms)
	
	def LatencyFrmNumAtStartPoint(self, cache: dict = {}) -> int:
		"""Start-point latency in frames: smoothing window plus optional look-back."""
		vad_latency = cache["windows_detector"].GetWinSize()
		if self.vad_opts.do_extend:
			vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
		return vad_latency
	
	def GetFrameState(self, t: int, cache: dict = {}):
		"""Classify frame ``t`` (index into the retained decibels/scores) as speech or silence.

		Silence frames also update the running noise level used for the SNR test.
		The caller feeds the returned state into ``DetectOneFrame``.
		"""
		frame_state = FrameState.kFrameStateInvalid
		cur_decibel = cache["stats"].decibel[t]
		cur_snr = cur_decibel - cache["stats"].noise_average_decibel
		# Too quiet to be speech.  The caller feeds this state to DetectOneFrame,
		# so it must not also be fed here (that would count the frame twice).
		if cur_decibel < self.vad_opts.decibel_thres:
			return FrameState.kFrameStateSil
		
		# For each frame, compute the log posterior probability of each state.
		sum_score = 0.0
		noise_prob = 0.0
		assert len(cache["stats"].sil_pdf_ids) == self.vad_opts.silence_pdf_num
		if len(cache["stats"].sil_pdf_ids) > 0:
			assert len(cache["stats"].scores) == 1  # only batch_size == 1 is supported
			sil_pdf_scores = [cache["stats"].scores[0][t][sil_pdf_id] for sil_pdf_id in cache["stats"].sil_pdf_ids]
			sum_score = sum(sil_pdf_scores)
			noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
			total_score = 1.0
			sum_score = total_score - sum_score
		speech_prob = math.log(sum_score)
		if self.vad_opts.output_frame_probs:
			frame_prob = E2EVadFrameProb()
			frame_prob.noise_prob = noise_prob
			frame_prob.speech_prob = speech_prob
			frame_prob.score = sum_score
			frame_prob.frame_id = t
			cache["stats"].frame_probs.append(frame_prob)
		if math.exp(speech_prob) >= math.exp(noise_prob) + cache["stats"].speech_noise_thres:
			if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
				frame_state = FrameState.kFrameStateSpeech
			else:
				frame_state = FrameState.kFrameStateSil
		else:
			frame_state = FrameState.kFrameStateSil
			# Running average of the noise level; -100 means "not initialised yet".
			if cache["stats"].noise_average_decibel < -99.9:
				cache["stats"].noise_average_decibel = cur_decibel
			else:
				cache["stats"].noise_average_decibel = (cur_decibel + cache["stats"].noise_average_decibel * (
					self.vad_opts.noise_frame_num_used_for_snr
					- 1)) / self.vad_opts.noise_frame_num_used_for_snr
		
		return frame_state
	
	def forward(self, feats: torch.Tensor,
	            waveform: torch.Tensor,
	            cache: dict = {},
	            is_final: bool = False,
	            **kwargs,
	            ):
		"""Process one chunk and return the segments that became reportable.

		Returns a list with one entry per batch item that produced segments;
		each entry is a list of ``[start_ms, end_ms]``.  With
		``is_streaming_input`` (default True here), still-open ends are
		reported as -1: ``[beg, -1]`` when a segment starts, ``[-1, end]``
		when it later closes.  Otherwise only complete segments are returned,
		plus any open segment on the final chunk.
		"""
		cache["stats"].waveform = waveform
		is_streaming_input = kwargs.get("is_streaming_input", True)
		self.ComputeDecibel(cache=cache)
		self.ComputeScores(feats, cache=cache)
		if not is_final:
			self.DetectCommonFrames(cache=cache)
		else:
			self.DetectLastFrames(cache=cache)
		segments = []
		for batch_num in range(0, feats.shape[0]):  # only support batch_size = 1 now
			segment_batch = []
			if len(cache["stats"].output_data_buf) > 0:
				for i in range(cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)):
					if is_streaming_input: # in this case, return [beg, -1], [], [-1, end], [beg, end]
						if not cache["stats"].output_data_buf[i].contain_seg_start_point:
							continue
						if not cache["stats"].next_seg and not cache["stats"].output_data_buf[i].contain_seg_end_point:
							continue
						start_ms = cache["stats"].output_data_buf[i].start_ms if cache["stats"].next_seg else -1
						if cache["stats"].output_data_buf[i].contain_seg_end_point:
							end_ms = cache["stats"].output_data_buf[i].end_ms
							cache["stats"].next_seg = True
							cache["stats"].output_data_buf_offset += 1
						else:
							end_ms = -1
							cache["stats"].next_seg = False
						segment = [start_ms, end_ms]
						
					else: # in this case, return [beg, end]
						
						if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not
						cache["stats"].output_data_buf[
							i].contain_seg_end_point):
							continue
						segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms]
						cache["stats"].output_data_buf_offset += 1  # need update this parameter
					
					segment_batch.append(segment)
					
			if segment_batch:
				segments.append(segment_batch)
		return segments
	
	def init_cache(self, cache: dict = None, **kwargs):
		"""(Re)initialise ``cache`` in place for a new stream and return it.

		A fresh dict is created when ``cache`` is omitted.  The previous shared
		``{}`` default made every such call return the same object.  Passing
		``max_end_silence_time`` updates the shared ``vad_opts`` before the
		stream state is built.
		"""
		if cache is None:
			cache = {}
		
		cache["frontend"] = {}
		cache["prev_samples"] = torch.empty(0)
		cache["encoder"] = {}

		if kwargs.get("max_end_silence_time") is not None:
			# update the max_end_silence_time
			self.vad_opts.max_end_silence_time = kwargs.get("max_end_silence_time")

		windows_detector = WindowDetector(self.vad_opts.window_size_ms,
		                                  self.vad_opts.sil_to_speech_time_thres,
		                                  self.vad_opts.speech_to_sil_time_thres,
		                                  self.vad_opts.frame_in_ms)
		windows_detector.Reset()
		
		stats = Stats(sil_pdf_ids=self.vad_opts.sil_pdf_ids,
		              max_end_sil_frame_cnt_thresh=self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres,
		              speech_noise_thres=self.vad_opts.speech_noise_thres
		              )
		cache["windows_detector"] = windows_detector
		cache["stats"] = stats
		return cache
	
	def inference(self,
	              data_in,
	              data_lengths=None,
	              key: list = None,
	              tokenizer=None,
	              frontend=None,
	              cache: dict = None,
	              **kwargs,
	              ):
		"""Run VAD over ``data_in``, splitting it into ``chunk_size`` ms pieces.

		Samples that do not fill a whole chunk are kept in
		``cache["prev_samples"]`` and prepended to the next call.  On the final
		call the cache is reinitialised.  Returns ``(results, meta_data)``, where
		``results`` is ``[{"key": key[0], "value": segments}]``.
		"""
		if cache is None:
			# A fresh dict per call; the old ``{}`` default leaked state between calls.
			cache = {}
		if len(cache) == 0:
			self.init_cache(cache, **kwargs)
		
		meta_data = {}
		chunk_size = kwargs.get("chunk_size", 60000)  # chunk length in ms
		chunk_stride_samples = int(chunk_size * frontend.fs / 1000)
		
		time1 = time.perf_counter()
		is_streaming_input = kwargs.get("is_streaming_input", False) if chunk_size >= 15000 else kwargs.get("is_streaming_input", True)
		is_final = kwargs.get("is_final", False) if is_streaming_input else kwargs.get("is_final", True)
		cfg = {"is_final": is_final, "is_streaming_input": is_streaming_input}
		audio_sample_list = load_audio_text_image_video(data_in,
		                                                fs=frontend.fs,
		                                                audio_fs=kwargs.get("fs", 16000),
		                                                data_type=kwargs.get("data_type", "sound"),
		                                                tokenizer=tokenizer,
		                                                cache=cfg,
		                                                )
		_is_final = cfg["is_final"]  # if data_in is a file or url, set is_final=True
		is_streaming_input = cfg["is_streaming_input"]
		time2 = time.perf_counter()
		meta_data["load_data"] = f"{time2 - time1:0.3f}"
		assert len(audio_sample_list) == 1, "batch_size must be set 1"
		
		audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
		
		# n full chunks (plus the partial tail when final); m = samples left over otherwise.
		n = int(len(audio_sample) // chunk_stride_samples + int(_is_final))
		m = int(len(audio_sample) % chunk_stride_samples * (1 - int(_is_final)))
		segments = []
		for i in range(n):
			kwargs["is_final"] = _is_final and i == n - 1
			audio_sample_i = audio_sample[i * chunk_stride_samples:(i + 1) * chunk_stride_samples]
			
			# extract fbank feats
			speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
			                                       frontend=frontend, cache=cache["frontend"],
			                                       is_final=kwargs["is_final"])
			time3 = time.perf_counter()
			meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
			meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
			speech = speech.to(device=kwargs["device"])
			speech_lengths = speech_lengths.to(device=kwargs["device"])
			
			batch = {
				"feats": speech,
				"waveform": cache["frontend"]["waveforms"],
				"is_final": kwargs["is_final"],
				"cache": cache,
				"is_streaming_input": is_streaming_input
			}
			segments_i = self.forward(**batch)
			if len(segments_i) > 0:
				segments.extend(*segments_i)
		
		# Keep only the unprocessed tail (the last m samples).  The previous
		# ``audio_sample[:-m]`` kept the already-processed audio instead.
		cache["prev_samples"] = audio_sample[len(audio_sample) - m:]
		if _is_final:
			self.init_cache(cache)
		
		ibest_writer = None
		if kwargs.get("output_dir") is not None:
			if not hasattr(self, "writer"):
				self.writer = DatadirWriter(kwargs.get("output_dir"))
			ibest_writer = self.writer[f"{1}best_recog"]
		
		results = []
		result_i = {"key": key[0], "value": segments}
		
		results.append(result_i)
		
		if ibest_writer is not None:
			ibest_writer["text"][key[0]] = segments
		
		return results, meta_data
	
	def export(self, **kwargs):
		"""Build export-ready model(s) via ``export_meta.export_rebuild_model``."""
		from .export_meta import export_rebuild_model
		models = export_rebuild_model(model=self, **kwargs)
		return models

	def DetectCommonFrames(self, cache: dict = {}) -> int:
		"""Run the state machine over the frames of a non-final chunk. Always returns 0."""
		if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
			return 0
		for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
			# Decibels/scores are indexed relative to the dropped prefix; the state
			# machine uses absolute frame indices.
			frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames,
			                                 cache=cache)
			self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
		
		return 0
	
	def DetectLastFrames(self, cache: dict = {}) -> int:
		"""Like DetectCommonFrames, but marks the very last frame as the stream's final frame."""
		if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
			return 0
		for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
			frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames,
			                                 cache=cache)
			if i != 0:
				self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
			else:
				self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1, True, cache=cache)
		
		return 0
	
	def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool,
	                   cache: dict = {}) -> None:
		"""Advance the endpoint state machine by one frame.

		The frame label is smoothed by the window detector; the resulting
		transition drives segment start/end decisions.  In multiple-utterance
		mode the detector is reset after every end point.
		"""
		tmp_cur_frm_state = FrameState.kFrameStateInvalid
		if cur_frm_state == FrameState.kFrameStateSpeech:
			if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
				tmp_cur_frm_state = FrameState.kFrameStateSpeech
			else:
				tmp_cur_frm_state = FrameState.kFrameStateSil
		elif cur_frm_state == FrameState.kFrameStateSil:
			tmp_cur_frm_state = FrameState.kFrameStateSil
		state_change = cache["windows_detector"].DetectOneFrame(tmp_cur_frm_state, cur_frm_idx, cache=cache)
		frm_shift_in_ms = self.vad_opts.frame_in_ms
		if AudioChangeState.kChangeStateSil2Speech == state_change:
			silence_frame_count = cache["stats"].continous_silence_frame_count
			cache["stats"].continous_silence_frame_count = 0
			cache["stats"].pre_end_silence_detected = False
			start_frame = 0
			if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
				# Start the segment one latency window back (but not before the buffered audio).
				start_frame = max(cache["stats"].data_buf_start_frame,
				                  cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
				self.OnVoiceStart(start_frame, cache=cache)
				cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
				for t in range(start_frame + 1, cur_frm_idx + 1):
					self.OnVoiceDetected(t, cache=cache)
			elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
				for t in range(cache["stats"].latest_confirmed_speech_frame + 1, cur_frm_idx):
					self.OnVoiceDetected(t, cache=cache)
				if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
					self.vad_opts.max_single_segment_time / frm_shift_in_ms:
					self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
					cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
				elif not is_final_frame:
					self.OnVoiceDetected(cur_frm_idx, cache=cache)
				else:
					self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
			else:
				pass
		elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
			cache["stats"].continous_silence_frame_count = 0
			if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
				pass
			elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
				if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
					self.vad_opts.max_single_segment_time / frm_shift_in_ms:
					self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
					cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
				elif not is_final_frame:
					self.OnVoiceDetected(cur_frm_idx, cache=cache)
				else:
					self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
			else:
				pass
		elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
			cache["stats"].continous_silence_frame_count = 0
			if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
				if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
					self.vad_opts.max_single_segment_time / frm_shift_in_ms:
					cache["stats"].max_time_out = True
					self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
					cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
				elif not is_final_frame:
					self.OnVoiceDetected(cur_frm_idx, cache=cache)
				else:
					self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
			else:
				pass
		elif AudioChangeState.kChangeStateSil2Sil == state_change:
			cache["stats"].continous_silence_frame_count += 1
			if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
				# silence timeout, return zero length decision
				if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
					cache[
						"stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
					or (is_final_frame and cache["stats"].number_end_time_detected == 0):
					for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx):
						self.OnSilenceDetected(t, cache=cache)
					self.OnVoiceStart(0, True, cache=cache)
					self.OnVoiceEnd(0, True, False, cache=cache)
					cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
				else:
					if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache):
						self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache)
			elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
				if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache[
					"stats"].max_end_sil_frame_cnt_thresh:
					# Trailing silence long enough: place the end point back inside the
					# silence, keeping lookahead_time_end_point of it when do_extend is set.
					lookback_frame = int(cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
					if self.vad_opts.do_extend:
						lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
						lookback_frame -= 1
						lookback_frame = max(0, lookback_frame)
					self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False, cache=cache)
					cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
				elif cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
					self.vad_opts.max_single_segment_time / frm_shift_in_ms:
					self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
					cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
				elif self.vad_opts.do_extend and not is_final_frame:
					if cache["stats"].continous_silence_frame_count <= int(
						self.vad_opts.lookahead_time_end_point / frm_shift_in_ms):
						self.OnVoiceDetected(cur_frm_idx, cache=cache)
				else:
					self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
			else:
				pass
		
		if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
			self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
			self.ResetDetection(cache=cache)


